{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Import Python Modules"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "import matplotlib.pyplot as plt\n",
    "from sklearn.model_selection import train_test_split\n",
    "import copy"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Generate data for Regression\n",
    "\n",
    "In the tree of depth 1 example that we looked at, the model we fit had two terminal nodes and a single feature. Here we're going to fit a model with eight terminal nodes and two features."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "n = 2000  # Number of observations in the training set\n",
    "\n",
    "# Split thresholds of the data-generating tree used by generateY:\n",
    "# theta[0] and theta[3] are compared against x[0], theta[1] and theta[2] against x[1].\n",
    "theta = [4, 4, 10, 12]\n",
    "# Mean response of each of the five terminal nodes.\n",
    "c = [15, 40, 5, 30, 10]\n",
    "\n",
    "# Seed the global NumPy RNG so that Restart & Run All regenerates the same data.\n",
    "np.random.seed(0)\n",
    "\n",
    "x0 = np.random.uniform(0, 16, n)\n",
    "x1 = np.random.uniform(0, 16, n)\n",
    "\n",
    "# Stack the two features as columns so that row i is (x0[i], x1[i]).\n",
    "# np.array([x0, x1]).reshape((-1, 2)) would instead pair consecutive x0 values\n",
    "# (and later consecutive x1 values) in each row.\n",
    "x = np.column_stack((x0, x1))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "def generateY(x, splitPoints, theta, sd):\n",
    "    \"\"\"Draw one noisy response for a single observation from a fixed regression tree.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    x : sequence of two numbers\n",
    "        Feature values (x[0], x[1]) of the observation.\n",
    "    splitPoints : sequence of five numbers\n",
    "        Mean response of each terminal node. Despite the name, the call\n",
    "        below passes the list of leaf means (c) here, so that is how it is used.\n",
    "    theta : sequence of four numbers\n",
    "        Split thresholds: theta[0] and theta[3] apply to x[0],\n",
    "        theta[1] and theta[2] apply to x[1].\n",
    "    sd : float\n",
    "        Standard deviation of the Gaussian noise around the leaf mean.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    scalar\n",
    "        One draw from Normal(leaf mean, sd).\n",
    "    \"\"\"\n",
    "    # Walk the tree to find the terminal node the observation falls into.\n",
    "    if x[0] > theta[0]:\n",
    "        if x[1] > theta[2]:\n",
    "            mean = splitPoints[0]\n",
    "        elif x[0] <= theta[3]:\n",
    "            mean = splitPoints[1]\n",
    "        else:\n",
    "            mean = splitPoints[2]\n",
    "    elif x[1] <= theta[1]:\n",
    "        mean = splitPoints[3]\n",
    "    else:\n",
    "        mean = splitPoints[4]\n",
    "\n",
    "    # Draw an array of size 1 and unwrap it so that a scalar is returned.\n",
    "    return np.random.normal(mean, sd, 1)[0]\n",
    "\n",
    "\n",
    "y = [generateY(i, c, theta, 3) for i in x]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [ "
\n", " | X0 | \n", "X1 | \n", "Y | \n", "
---|---|---|---|
1615 | \n", "13.695913 | \n", "0.961214 | \n", "10.334377 | \n", "
1518 | \n", "5.895645 | \n", "9.269740 | \n", "36.857414 | \n", "
1276 | \n", "9.927904 | \n", "15.067658 | \n", "17.251418 | \n", "
1971 | \n", "5.740117 | \n", "0.098567 | \n", "36.432726 | \n", "
149 | \n", "3.203203 | \n", "8.391481 | \n", "14.792564 | \n", "